{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33m8christmas8\u001b[0m (use `wandb login --relogin` to force relogin)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "                    Syncing run <strong><a href=\"https://wandb.ai/8christmas8/rdrop_demo/runs/3bzygglj\" target=\"_blank\">bert_base</a></strong> to <a href=\"https://wandb.ai/8christmas8/rdrop_demo\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">docs</a>).<br/>\n",
       "\n",
       "                "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from pandas import DataFrame\n",
    "import numpy as np\n",
    "from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, classification_report\n",
    "import torch\n",
    "import os\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "from typing import Counter\n",
    "import pickle\n",
    "from transformers import TrainingArguments\n",
    "from transformers import Trainer\n",
    "from torch.nn import CrossEntropyLoss\n",
    "import torch.nn.functional as F\n",
    "from transformers import BertTokenizer, BertForSequenceClassification\n",
    "from transformers import EarlyStoppingCallback\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from imblearn.over_sampling import RandomOverSampler\n",
    "from sklearn.model_selection import train_test_split\n",
    "from typing import Optional, List\n",
    "import wandb\n",
    "wandb.init(project=\"rdrop_demo\", \n",
    "           name=\"bert_base\",\n",
    "           tags=[\"baseline\"],\n",
    "           group=\"bert\")\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"0\"\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "data_source_path = \"/data/tmp/nlp-data/open_source_data/classification/ChnSentiCorp_htl_all.csv\"\n",
    "output_path = \"/data/christmas.wang/project/project_demo/output/classification/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>label</th>\n",
       "      <th>review</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1</td>\n",
       "      <td>CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   label                                             review\n",
       "0      1  距离川沙公路较近,但是公交指示不对,如果是\"蔡陆线\"的话,会非常麻烦.建议用别的路线.房间较...\n",
       "1      1                       商务大床房，房间很大，床有2M宽，整体感觉经济实惠不错!\n",
       "2      1         早餐太差，无论去多少人，那边也不加食品的。酒店应该重视一下这个问题了。房间本身很好。\n",
       "3      1  宾馆在小街道上，不大好找，但还好北京热心同胞很多~宾馆设施跟介绍的差不多，房间很小，确实挺小...\n",
       "4      1               CBD中心,周围没什么店铺,说5星有点勉强.不知道为什么卫生间没有电吹风"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_source_df = pd.read_csv(data_source_path)\n",
    "data_source_df.dropna(how=\"any\", axis=0, inplace=True)\n",
    "data_source_df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1    5322\n",
       "0    2443\n",
       "Name: label, dtype: int64"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = data_source_df[\"label\"].value_counts()\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train contains [6212] records & Test contain s [1553] records\n"
     ]
    }
   ],
   "source": [
    "train_df, test_df = train_test_split(data_source_df, test_size=0.2, random_state=42, shuffle=True)\n",
    "print(\"Train contains [{}] records & Test contain s [{}] records\".format(train_df.shape[0], test_df.shape[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_metrics(p):\n",
    "    pred, labels = p\n",
    "    pred = np.argmax(pred, axis=1)\n",
    "\n",
    "    accuracy = accuracy_score(y_true=labels, y_pred=pred)\n",
    "    recall = recall_score(y_true=labels, y_pred=pred, average='binary')\n",
    "    precision = precision_score(y_true=labels, y_pred=pred, average='binary')\n",
    "    f1 = f1_score(y_true=labels, y_pred=pred, average='binary')\n",
    "\n",
    "    return {\"accuracy\": accuracy, \"precision\": precision, \"recall\": recall, \"f1\": f1}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MultilabelTrainer(Trainer):\n",
    "    \n",
    "    def compute_kl_loss(self, p, q):\n",
    "    \n",
    "        p_loss = F.kl_div(F.log_softmax(p, dim=-1), F.softmax(q, dim=-1), reduction='none')\n",
    "        q_loss = F.kl_div(F.log_softmax(q, dim=-1), F.softmax(p, dim=-1), reduction='none')\n",
    "\n",
    "        # You can choose whether to use function \"sum\" and \"mean\" depending on your task\n",
    "        p_loss = p_loss.sum()\n",
    "        q_loss = q_loss.sum()\n",
    "\n",
    "        loss = (p_loss + q_loss) / 2\n",
    "        return loss\n",
    "\n",
    "    def compute_loss(self, model, inputs, return_outputs=False):\n",
    "        \"\"\"\n",
    "        How the loss is computed by Trainer. By default, all models return the loss in the first element.\n",
    "\n",
    "        Subclass and override for custom behavior.\n",
    "        \"\"\"\n",
    "        labels = inputs.get(\"labels\")\n",
    "        outputs_1 = model(**inputs)\n",
    "        outputs_2 = model(**inputs)\n",
    "\n",
    "        \n",
    "        logits_1 = outputs_1.get('logits')\n",
    "        logits_2 = outputs_2.get('logits')\n",
    "        loss_fct = CrossEntropyLoss()\n",
    "        \n",
    "        # cross entropy loss for classifier\n",
    "        ce_loss = 0.5 * (loss_fct(logits_1, labels) + loss_fct(logits_2, labels))\n",
    "        kl_loss = self.compute_kl_loss(logits_1, logits_2)\n",
    "\n",
    "        # carefully choose hyper-parameters\n",
    "        loss = ce_loss + 4 * kl_loss\n",
    "        return (loss, outputs_1) if return_outputs else loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def data_oversampling(x: Optional[List], y: Optional[List]):\n",
    "    x = np.array(x).reshape(-1, 1)\n",
    "    ros = RandomOverSampler(random_state=0)\n",
    "    x, y = ros.fit_resample(x, y)\n",
    "    x = [ele[0] for ele in x]\n",
    "    return x, y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create torch dataset\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, encodings, labels=None):\n",
    "        self.encodings = encodings\n",
    "        self.labels = labels\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}\n",
    "        if self.labels:\n",
    "            item[\"labels\"] = torch.tensor(self.labels[idx])\n",
    "        return item\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.encodings[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classification_report_csv(report, path):\n",
    "    report_data = []\n",
    "    lines = report.split('\\n')\n",
    "    for line in lines[2:-5]:\n",
    "        row = {}\n",
    "        row_data = line.split('      ')\n",
    "        row['class'] = row_data[1]\n",
    "        row['precision'] = float(row_data[2].strip())\n",
    "        row['recall'] = float(row_data[3].strip())\n",
    "        row['f1_score'] = float(row_data[4].strip())\n",
    "        row['support'] = float(row_data[5].strip())\n",
    "        report_data.append(row)\n",
    "    dataframe = pd.DataFrame.from_dict(report_data)\n",
    "    dataframe.to_csv(path, index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(df: DataFrame, rdrop: False):\n",
    "    # Define pretrained tokenizer and model\n",
    "    tokenizer = BertTokenizer.from_pretrained('/data/tmp/christmas.wang/chinese_wwm_ext_pytorch/', do_lower_case=False)\n",
    "    model = BertForSequenceClassification.from_pretrained(\"/data/tmp/christmas.wang/chinese_wwm_ext_pytorch/\", num_labels=2, hidden_dropout_prob=0.3)\n",
    "    model = model.to(device)\n",
    "\n",
    "    # ********************** Data Process **********************\n",
    "    x = []\n",
    "    y = []\n",
    "    for index, row in tqdm(df.iterrows(), total=df.shape[0]):\n",
    "        text = str(row[\"review\"]).strip()\n",
    "        if len(text) < 1:\n",
    "            continue\n",
    "        x.append(text)\n",
    "        y.append(row[\"label\"])\n",
    "    # Label Encode\n",
    "    le = LabelEncoder()\n",
    "    le.fit(y)\n",
    "    y = list(le.transform(y))\n",
    "    output = open(os.path.join(output_path, \"output_encoder.pkl\"), 'wb')\n",
    "    pickle.dump(le, output) \n",
    "    print(dict(Counter(y)))\n",
    "    \n",
    "    x_train, x_val, y_train, y_val = train_test_split(x, y, test_size=0.2)\n",
    "    x_train_tokenized = tokenizer(x_train, padding=True, truncation=True, max_length=512)\n",
    "    x_val_tokenized = tokenizer(x_val, padding=True, truncation=True, max_length=512)\n",
    "    train_dataset = Dataset(x_train_tokenized, y_train)\n",
    "    val_dataset = Dataset(x_val_tokenized, y_val)\n",
    "\n",
    "    # ********************** train **********************\n",
    "    # Define Trainer\n",
    "    args = TrainingArguments(\n",
    "        report_to=\"wandb\",\n",
    "        output_dir=output_path,\n",
    "        evaluation_strategy=\"steps\",\n",
    "        eval_steps=500,\n",
    "        per_device_train_batch_size=16,\n",
    "        per_device_eval_batch_size=16,\n",
    "        learning_rate=2e-5,\n",
    "        num_train_epochs=6,\n",
    "        seed=0,\n",
    "        load_best_model_at_end=True,\n",
    "    )\n",
    "\n",
    "    trainer = Trainer(\n",
    "        model=model,\n",
    "        args=args,\n",
    "        train_dataset=train_dataset,\n",
    "        eval_dataset=val_dataset,\n",
    "        compute_metrics=compute_metrics,\n",
    "        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
    "    )\n",
    "    \n",
    "    if rdrop:\n",
    "        trainer = MultilabelTrainer(\n",
    "            model=model,\n",
    "            args=args,\n",
    "            train_dataset=train_dataset,\n",
    "            eval_dataset=val_dataset,\n",
    "            compute_metrics=compute_metrics,\n",
    "            callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],\n",
    "        )\n",
    "\n",
    "    trainer.train()\n",
    "    model_dir = os.path.join(output_path, \"model_save\")\n",
    "    model_to_save = model.module if hasattr(model, 'module') else model\n",
    "    torch.save(model_to_save.state_dict(), os.path.join(model_dir, \"pytorch_model.bin\"))\n",
    "    model_to_save.config.to_json_file(os.path.join(model_dir, \"config.json\"))\n",
    "    tokenizer.save_vocabulary(os.path.join(model_dir, \"vocab.txt\"))\n",
    "\n",
    "    # ********************** Predict **********************\n",
    "    x_test = list(test_df[\"review\"])\n",
    "    test_df['label_id'] = le.transform(test_df[\"label\"].tolist())\n",
    "    y_true = list(test_df['label_id'])\n",
    "    x_test_tokenized = tokenizer(x_test, padding=True, truncation=True, max_length=512)\n",
    "\n",
    "    # Create torch dataset\n",
    "    test_dataset = Dataset(x_test_tokenized)\n",
    "\n",
    "    model = BertForSequenceClassification.from_pretrained(model_dir, num_labels=2)\n",
    "    # Define test trainer\n",
    "    test_trainer = Trainer(model)\n",
    "\n",
    "    # Make prediction\n",
    "    raw_pred, _, _ = test_trainer.predict(test_dataset)\n",
    "\n",
    "    # Preprocess raw predictions\n",
    "    y_pred = np.argmax(raw_pred, axis=1)\n",
    "    report = classification_report(y_true, y_pred)\n",
    "    classification_report_csv(report, output_path+\"report.csv\")\n",
    "    print(report)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Didn't find file /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/added_tokens.json. We won't load it.\n",
      "Didn't find file /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/special_tokens_map.json. We won't load it.\n",
      "Didn't find file /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/tokenizer_config.json. We won't load it.\n",
      "Didn't find file /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/tokenizer.json. We won't load it.\n",
      "loading file /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/vocab.txt\n",
      "loading file None\n",
      "loading file None\n",
      "loading file None\n",
      "loading file None\n",
      "loading configuration file /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/config.json\n",
      "loading configuration file /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/config.json\n",
      "Model config BertConfig {\n",
      "  \"attention_probs_dropout_prob\": 0.1,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"directionality\": \"bidi\",\n",
      "  \"gradient_checkpointing\": false,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.3,\n",
      "  \"hidden_size\": 768,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 3072,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 512,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 12,\n",
      "  \"num_hidden_layers\": 12,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"pooler_fc_size\": 768,\n",
      "  \"pooler_num_attention_heads\": 12,\n",
      "  \"pooler_num_fc_layers\": 3,\n",
      "  \"pooler_size_per_head\": 128,\n",
      "  \"pooler_type\": \"first_token_transform\",\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.10.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 21128\n",
      "}\n",
      "\n",
      "loading weights file /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/pytorch_model.bin\n",
      "Some weights of the model checkpoint at /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/ were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /data/tmp/christmas.wang/chinese_wwm_ext_pytorch/ and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6212/6212 [00:00<00:00, 21344.62it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{1: 4274, 0: 1938}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "PyTorch: setting up devices\n",
      "***** Running training *****\n",
      "  Num examples = 4969\n",
      "  Num Epochs = 6\n",
      "  Instantaneous batch size per device = 16\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 16\n",
      "  Gradient Accumulation steps = 1\n",
      "  Total optimization steps = 1866\n",
      "Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='411' max='1866' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [ 411/1866 03:33 < 12:40, 1.91 it/s, Epoch 1.32/6]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "      <th>Validation Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "train(train_df, False)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "d4f9b45e4c5a885cd1a74da0410159ed9246fad19fa0f54d287889467781bee4"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
